#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <tuple>
#include <sys/cdefs.h>
#include "linalg.hpp"
#include "esmd_types.h"
#include "rigid.hpp"

struct shake {
  int maxiter = 200;
  real tol = 1e-6;
  template <typename... Ts>
  static __always_inline void eval_all(Ts... args) {
  }
  template <class T, int IC, int JC>
  real mass_factor(real *invmass) {
    constexpr int l0 = T::ijs[IC * 2];
    constexpr int r0 = T::ijs[IC * 2 + 1];
    constexpr int l1 = T::ijs[JC * 2];
    constexpr int r1 = T::ijs[JC * 2 + 1];
    constexpr int delta11 = l0 == l1;
    constexpr int delta12 = l0 == r1;
    constexpr int delta21 = r0 == l1;
    constexpr int delta22 = r0 == r1;
    return (delta11 - delta12) * invmass[l0] + (delta22 - delta21) * invmass[r0];
  }
  template <class T, int IJKC>
  real qmat_factor(real *invmass, real *rdot) {
    constexpr int kc = IJKC / T::ncons / T::ncons;
    constexpr int ic = IJKC / T::ncons % T::ncons;
    constexpr int jc = IJKC % T::ncons;
    constexpr int lk = T::ijs[kc * 2];
    constexpr int rk = T::ijs[kc * 2 + 1];
    constexpr int li = T::ijs[ic * 2];
    constexpr int ri = T::ijs[ic * 2 + 1];
    constexpr int lj = T::ijs[jc * 2];
    constexpr int rj = T::ijs[jc * 2 + 1];
    constexpr int deltai11 = lk == li;
    constexpr int deltai12 = lk == ri;
    constexpr int deltai21 = rk == li;
    constexpr int deltai22 = rk == ri;
    constexpr int deltaj11 = lk == lj;
    constexpr int deltaj12 = lk == rj;
    constexpr int deltaj21 = rk == lj;
    constexpr int deltaj22 = rk == rj;
    real lmassterm = (deltai11 - deltai12) * invmass[lk] + (deltai22 - deltai21) * invmass[rk];
    real rmassterm = (deltaj11 - deltaj12) * invmass[lk] + (deltaj22 - deltaj21) * invmass[rk];
    return lmassterm * rmassterm * rdot[ic * T::ncons + jc];
  }
  template <int I, int... Is>
  static constexpr int first() {
    return I;
  }
  template <class T, int... IAs, int... Is, int... Js, int... ICs, int... ICMs, int... JCMs, int... IJKCs>
  __always_inline void do_shake_static(vec<real> *ox, vec<real> *oxp, real *oinvmass, rigid_rec &rec,
                                       int *idx,
                                       utils::iseq<int, IAs...> iaseq,
                                       utils::iseq<int, Is...> iseq,
                                       utils::iseq<int, Js...> jseq,
                                       utils::iseq<int, ICs...> icseq,
                                       utils::iseq<int, ICMs...> icmatseq,
                                       utils::iseq<int, JCMs...> jcmatseq,
                                       utils::iseq<int, IJKCs...> ijkcseq) {
    vec<real> x[] = {ox[idx[IAs]]...};
    vec<real> xp[] = {oxp[idx[IAs]]...};
    real invmass[] = {oinvmass[idx[IAs]]...};
    vec<real> r[] = {x[Is] - x[Js]...};
    vec<real> ruc[] = {xp[Is] - xp[Js]...};
    real r0sq[] = {rec.r0[ICs] * rec.r0[ICs]...};
    

    // printmat<3>(qmat);
    real l[T::ncons];
    if (T::ncons == 1) {
      constexpr int i = first<Is...>();
      constexpr int j = first<Js...>();
      real a = (invmass[i] + invmass[j]) * (invmass[i] + invmass[j]) * r[0].norm2();
      real b = 2.0 * (invmass[i] + invmass[j]) * r[0].dot(ruc[0]);
      real c = ruc[0].norm2() - r0sq[0];
      real det = b * b - 4 * a * c;
      if (det < 0)
        det = 0;
      real lam1 = (-b + sqrt(det)) / (2.0 * a);
      real lam2 = (-b - sqrt(det)) / (2.0 * a);
      if (fabs(lam1) < fabs(lam2)) {
        l[0] = lam1;
      } else {
        l[0] = lam2;
      }
    } else {
      real rucsq[] = {ruc[ICs].norm2()...};
      real rdot[] = {r[ICMs].dot(r[JCMs])...};
      eval_all(l[ICs] = 0 ...);

      real A[] = {2.0 * r[ICMs].dot(ruc[JCMs]) * mass_factor<T, ICMs, JCMs>(invmass)...};
      real Ainv[T::ncons * T::ncons];
      matinv<T::ncons>(Ainv, A);
      real qmat[] = {
          qmat_factor<T, IJKCs>(invmass, rdot)...};
      int done = 0;
      for (int i = 0; i < maxiter; i++) {
        real lmat[] = {l[ICMs] * l[JCMs]...};
        real quad[] = {tdot<T::ncons * T::ncons>(qmat + ICs * T::ncons * T::ncons, lmat)...};
        real b[] = {r0sq[ICs] - rucsq[ICs] - quad[ICs]...};
        real lnew[] = {tdot<T::ncons>(Ainv + ICs * T::ncons, b)...};
        int almost_done = treduce::sum(fabs(lnew[ICs] - l[ICs]) < tol...) == T::ncons;
        eval_all(l[ICs] = lnew[ICs]...);
        if (done)
          break;
        done = almost_done;
      }
    }
    // printf("%f %f %f\n", l[0], l[1], l[2]);
    eval_all(xp[Is] += r[ICs] * l[ICs] * invmass[Is]...);
    eval_all(xp[Js] -= r[ICs] * l[ICs] * invmass[Js]...);
    eval_all(oxp[idx[IAs]] = xp[IAs]...);
  }

  template <class T, int... IAs, int... Is, int... Js, int... ICs, int... ICMs, int... JCMs, int... IJKCs>
  __always_inline void do_shake(vec<real> *f, real rdtfsq, vec<real> *ox, vec<real> *oxp, real *oinvmass, rigid_rec &rec,
                                int *idx,
                                utils::iseq<int, IAs...> iaseq,
                                utils::iseq<int, Is...> iseq,
                                utils::iseq<int, Js...> jseq,
                                utils::iseq<int, ICs...> icseq,
                                utils::iseq<int, ICMs...> icmatseq,
                                utils::iseq<int, JCMs...> jcmatseq,
                                utils::iseq<int, IJKCs...> ijkcseq) {
    vec<real> x[] = {ox[idx[IAs]]...};
    vec<real> xp[] = {oxp[idx[IAs]]...};
    real invmass[] = {oinvmass[idx[IAs]]...};
    vec<real> r[] = {x[Is] - x[Js]...};
    vec<real> ruc[] = {xp[Is] - xp[Js]...};
    real r0sq[] = {rec.r0[ICs] * rec.r0[ICs]...};

    real l[T::ncons];
    if (T::ncons == 1) {
      constexpr int i = first<Is...>();
      constexpr int j = first<Js...>();
      real a = (invmass[i] + invmass[j]) * (invmass[i] + invmass[j]) * r[0].norm2();
      real b = 2.0 * (invmass[i] + invmass[j]) * r[0].dot(ruc[0]);
      real c = ruc[0].norm2() - r0sq[0];
      real det = b * b - 4 * a * c;
      if (det < 0)
        det = 0;
      real lam1 = (-b + sqrt(det)) / (2.0 * a);
      real lam2 = (-b - sqrt(det)) / (2.0 * a);
      if (fabs(lam1) < fabs(lam2)) {
        l[0] = lam1;
      } else {
        l[0] = lam2;
      }
    } else {
      real rucsq[] = {ruc[ICs].norm2()...};
      real rdot[] = {r[ICMs].dot(r[JCMs])...};
      eval_all(l[ICs] = 0 ...);

      real A[] = {2.0 * r[ICMs].dot(ruc[JCMs]) * mass_factor<T, ICMs, JCMs>(invmass)...};
      real Ainv[T::ncons * T::ncons];
      matinv<T::ncons>(Ainv, A);
      real qmat[] = {
          qmat_factor<T, IJKCs>(invmass, rdot)...};
      int done = 0;
      for (int i = 0; i < maxiter; i++) {
        real lmat[] = {l[ICMs] * l[JCMs]...};
        real quad[] = {tdot<T::ncons * T::ncons>(qmat + ICs * T::ncons * T::ncons, lmat)...};
        real b[] = {r0sq[ICs] - rucsq[ICs] - quad[ICs]...};
        real lnew[] = {tdot<T::ncons>(Ainv + ICs * T::ncons, b)...};
        int almost_done = treduce::sum(fabs(lnew[ICs] - l[ICs]) < tol...) == T::ncons;
        eval_all(l[ICs] = lnew[ICs]...);
        if (done)
          break;
        done = almost_done;
      }
    }

    eval_all(f[idx[Is]] += r[ICs] * l[ICs] * rdtfsq...);
    eval_all(f[idx[Js]] -= r[ICs] * l[ICs] * rdtfsq...);
  }

  template <class T>
  void run(vec<real> *f, real rdtfsq, vec<real> *x, vec<real> *xp, real *invmass, rigid_rec &rec, int *idx) {
    do_shake<T>(f, rdtfsq, x, xp, invmass, rec, idx, T::iaseq, T::iseq, T::jseq, T::icseq, T::icmatseq, T::jcmatseq, T::ic3seq);
  }

  template <class T>
  void run_static(vec<real> *x, vec<real> *xp, real *invmass, rigid_rec &rec, int *idx) {
    do_shake_static<T>(x, xp, invmass, rec, idx, T::iaseq, T::iseq, T::jseq, T::icseq, T::icmatseq, T::jcmatseq, T::ic3seq);
  }
};

// Free-function entry points of the SHAKE constraint code. Only the
// declarations live in this header; the definitions are elsewhere.
// NOTE(review): judging by the names, shake_setup presumably prepares the
// constraints before a run and shake_post_force presumably applies them after
// the force computation of a step; `dt` looks like the time step and `ftm2v`
// like a force-to-velocity unit conversion factor -- confirm against the
// implementation.
void shake_setup(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v);
void shake_post_force(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v);
